import functools
import gym
import numpy as np
import scipy.linalg as la
from gym.envs.registration import register
from scipy import stats
from scipy.sparse.csgraph import connected_components
from scipy.special import expit, logit
import pandas as pd

def generate_orthogonal_matrix(n, rng):
    """Return the Q factor of the QR decomposition of a random Gaussian matrix.

    Args:
        n: Side length of the square matrix.
        rng: Random source exposing ``normal(loc, scale, size)``
            (e.g. ``np.random.RandomState``).

    Returns:
        The ``(n, n)`` orthogonal factor ``Q`` from ``np.linalg.qr``.
    """
    gaussian = rng.normal(0, 1, (n, n))
    # Only the orthogonal factor is needed; the triangular factor is discarded.
    return np.linalg.qr(gaussian)[0]


def get_communicating_classes(matrix):
    """Group the states of a transition matrix into communicating classes.

    Two states communicate when each is reachable from the other, i.e. the
    classes are the *strongly* connected components of the directed graph
    whose edges are the non-zero entries of ``matrix``. SciPy's default for
    directed graphs is ``connection='weak'``, which would merge states linked
    by a one-way transition only, so ``connection='strong'`` is requested
    explicitly. When every ``i -> j`` edge has a matching ``j -> i`` edge the
    weak and strong components coincide.

    Args:
        matrix: Square array (or sparse graph) of transition weights; non-zero
            entries are treated as edges.

    Returns:
        A list of classes, each a list of state indices in increasing order.
        The order of the classes follows SciPy's component labels.
    """
    n_components, labels = connected_components(
        csgraph=matrix, directed=True, connection="strong", return_labels=True
    )
    classes = [[] for _ in range(n_components)]
    for index, label in enumerate(labels):
        classes[label].append(index)
    return classes


def calculate_stationary_distribution(matrix, communicating_classes, even_class_distribution=False):
    """Combine per-class stationary distributions into one distribution.

    For each class the sub-chain restricted to that class is solved in the
    least-squares sense for ``pi`` with ``pi @ P_sub = pi`` and
    ``sum(pi) = 1``. Each class solution is then scaled by a class weight and
    the assembled vector is renormalised to sum to one.

    Args:
        matrix: Square transition matrix of shape ``(n, n)``.
        communicating_classes: List of lists of state indices.
        even_class_distribution: If true every class gets weight
            ``1 / number_of_classes``; otherwise a class is weighted by its
            share of the states (``len(class) / n``).

    Returns:
        A length-``n`` array holding the combined distribution.
    """
    total_states = matrix.shape[0]
    combined = np.zeros(total_states)
    uniform_weight = 1 / len(communicating_classes) if even_class_distribution else None

    for members in communicating_classes:
        block = matrix[members][:, members]
        size = len(block)

        # Balance equations (P_sub^T - I) pi = 0 stacked on top of the
        # normalisation row sum(pi) = 1.
        system = np.vstack((block.T - np.eye(size), np.ones(size)))
        target = np.append(np.zeros(size), 1.0)
        class_pi = la.lstsq(system, target)[0]

        class_weight = uniform_weight if even_class_distribution else size / total_states
        combined[members] = class_pi * class_weight

    return combined / combined.sum()


class EpiCare(gym.Env):
    """Environment to model disease treatment using RL.

    The environment is generated once, in ``__init__``, from a dedicated
    ``np.random.RandomState(seed)``: a set of diseases (each with a Gaussian
    symptom model and a few curative treatments), a set of treatments, and a
    disease-to-disease transition matrix whose stationary distribution is used
    to draw the initial disease of every episode.

    Observations are ``n_symptoms`` values in ``[0, 1]``; actions are treatment
    indices in ``range(n_treatments)``.

    Per-episode randomness (initial disease, symptoms, remission draws) comes
    from NumPy's global RNG unless ``reset(seed=...)`` has been called, in
    which case it comes from a RNG private to this instance.
    """

    # Per-instance episode RNG installed by ``reset(seed=...)``. ``None`` means
    # "use NumPy's global RNG". Kept as a class-level default (and never set to
    # the ``np.random`` module itself) so instances stay picklable/copyable.
    _episode_rng = None

    def __init__(
        self,
        n_diseases=16,
        n_treatments=16,
        n_symptoms=8,
        remission_reward=64,
        adverse_event_reward=-64,
        adverse_event_threshold=0.999,
        max_visits=8,
        seed=1,
    ):
        """Build the environment and start the first episode.

        Args:
            n_diseases: Number of distinct diseases.
            n_treatments: Number of available treatments (size of the action
                space).
            n_symptoms: Number of symptoms (length of an observation).
            remission_reward: Reward added when a treatment induces remission.
            adverse_event_reward: Reward added when an adverse event occurs.
            adverse_event_threshold: An adverse event is triggered when the
                largest symptom value exceeds this threshold.
            max_visits: Maximum number of steps per episode.
            seed: Seed for the RNG that generates diseases, treatments and the
                transition matrix. It does not seed per-episode sampling.
        """
        super().__init__()
        rng = np.random.RandomState(seed)

        self.n_diseases = n_diseases
        self.n_treatments = n_treatments
        self.n_symptoms = n_symptoms
        self.remission_reward = remission_reward
        self.adverse_event_reward = adverse_event_reward
        self.adverse_event_threshold = adverse_event_threshold
        self.max_visits = max_visits
        self.connection_probability = 1 / n_diseases

        self.action_space = gym.spaces.Discrete(n_treatments)
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(n_symptoms,), dtype=np.float32
        )

        # NOTE(review): computed here but not read by any method of this class.
        self.symptom_reward_multiplier = remission_reward / (2 * max_visits * n_symptoms)

        # The order of these calls fixes the order of RNG draws and therefore
        # the generated environment for a given seed; do not reorder.
        self.generate_diseases(rng)
        self.disease_list = list(self.diseases.keys())
        self.treatments = self.generate_treatments(rng)
        self.generate_transition_matrix(rng)
        self.compute_stationary_distribution()
        self.reset()

    @property
    def _sampler(self):
        """Source of per-episode randomness.

        Returns the instance RNG installed by ``reset(seed=...)`` if there is
        one, otherwise the ``np.random`` module (NumPy's global RNG).
        """
        return np.random if self._episode_rng is None else self._episode_rng

    def generate_transition_matrix(self, rng):
        """Create ``self.transition_matrix`` with rows that sum to one.

        Starting from the identity, every ordered pair ``(i, j)`` with
        ``i != j`` is connected with probability ``connection_probability``;
        a connection writes independent ``uniform(0.01, 0.2)`` weights in both
        directions. The diagonal is then set to one minus the off-diagonal row
        sum.

        Args:
            rng: ``np.random.RandomState``-like generator.
        """
        self.transition_matrix = np.eye(self.n_diseases)
        # Both (i, j) and (j, i) are visited, so a pair can be connected on
        # either visit and a later visit overwrites weights set by an earlier
        # one.
        for i in range(self.n_diseases):
            for j in range(self.n_diseases):
                if i != j:
                    if rng.rand() < self.connection_probability:
                        self.transition_matrix[i][j] = rng.uniform(0.01, 0.2)
                        self.transition_matrix[j][i] = rng.uniform(0.01, 0.2)
        # The diagonal is still 1 here, so subtracting 1 leaves the
        # off-diagonal mass of each row.
        # NOTE(review): a row with more than five neighbours could exceed 1
        # and produce a negative diagonal entry; nothing guards against that.
        row_sums = (self.transition_matrix.sum(axis=1) - 1)
        self.transition_matrix[np.arange(self.n_diseases), np.arange(self.n_diseases)] = (1 - row_sums)

    def compute_stationary_distribution(self, even_class_distribution=False):
        """Compute the communicating classes and the initial-disease distribution.

        Stores the results in ``self.communicating_classes`` and
        ``self.stationary_distribution``.

        Args:
            even_class_distribution: Forwarded to
                ``calculate_stationary_distribution``.
        """
        self.communicating_classes = get_communicating_classes(self.transition_matrix)
        self.stationary_distribution = calculate_stationary_distribution(
            self.transition_matrix, self.communicating_classes, even_class_distribution
        )

    def generate_diseases(self, rng):
        """Populate ``self.diseases`` and assign each treatment to one disease.

        Each disease ``"Disease_<i>"`` gets symptom means drawn from
        ``uniform(0, 2)`` and a covariance ``P @ diag(std**2) @ P.T`` where the
        standard deviations are drawn from ``uniform(1, 2)`` and sorted in
        decreasing order, and ``P`` is a random orthogonal matrix.

        Every treatment index is then attached to one uniformly chosen disease
        with a remission probability drawn from ``uniform(0.8, 1.0)``; a
        disease may end up with no treatment, or with several.

        Args:
            rng: ``np.random.RandomState``-like generator.
        """
        self.diseases = {}
        for i in range(self.n_diseases):
            means = rng.uniform(0.0, 2.0, size=self.n_symptoms)
            std_devs = rng.uniform(1.0, 2.0, size=self.n_symptoms)
            std_devs_sorted = np.sort(std_devs)[::-1]
            P = generate_orthogonal_matrix(self.n_symptoms, rng)
            Sigma = P @ np.diag(std_devs_sorted**2) @ P.T
            self.diseases[f"Disease_{i}"] = {
                "symptom_means": means,
                "symptom_covariances": Sigma,
                "remission_probs": dict(),
                "base_cost": 0,
                "treatments": [],
            }
        for treatment in range(self.n_treatments):
            disease = f"Disease_{rng.randint(0, self.n_diseases)}"
            self.diseases[disease]["treatments"].append(treatment)
            self.diseases[disease]["remission_probs"][treatment] = rng.uniform(0.8, 1.0)

    def sample_symptoms(self):
        """Draw a symptom vector for ``self.current_disease``.

        Returns:
            An array of ``n_symptoms`` values: ``uniform(0, 0.1)`` noise when
            the patient is in remission, otherwise a multivariate normal
            sample from the disease's symptom model squashed through the
            logistic function into ``(0, 1)``.
        """
        sampler = self._sampler
        if self.current_disease == "Remission":
            return sampler.uniform(0.0, 0.1, self.n_symptoms)
        mu = self.diseases[self.current_disease]["symptom_means"]
        cov = self.diseases[self.current_disease]["symptom_covariances"]
        values = sampler.multivariate_normal(mu, cov)
        return expit(values)

    def generate_treatments(self, rng):
        """Create the treatment table.

        Each treatment ``"Treatment_<i>"`` holds a ``base_cost`` drawn from
        ``uniform(1, 5)``, a ``treatment_effects`` vector that is zero except
        for two distinct symptoms with ``uniform(-1, 1)`` values, and
        per-disease ``transition_modifiers`` drawn from ``uniform(0.5, 1.5)``.

        Args:
            rng: ``np.random.RandomState``-like generator.

        Returns:
            Dict mapping treatment name to its parameter dict.
        """
        treatments = {}
        for i in range(self.n_treatments):
            base_cost = rng.uniform(1, 5)
            effects = np.zeros(self.n_symptoms)
            # Python evaluates the right-hand side before the subscript, so the
            # uniform draw happens *before* the choice draw. Keep this as one
            # statement to preserve the RNG stream.
            effects[rng.choice(self.n_symptoms, size=2, replace=False)] = rng.uniform(-1, 1, size=2)
            modifiers = rng.uniform(0.5, 1.5, size=self.n_diseases)
            treatments[f"Treatment_{i}"] = dict(
                base_cost=base_cost,
                treatment_effects=effects,
                transition_modifiers=modifiers,
            )
        return treatments

    def reset(self, *, seed=None, options=None):
        """Start a new episode.

        The initial disease is drawn from ``self.stationary_distribution``.

        Args:
            seed: If given, per-episode sampling for this and all later
                episodes of this instance switches to a fresh
                ``np.random.RandomState(seed)``. If ``None``, the current
                source (NumPy's global RNG by default) keeps being used.
            options: Accepted for API compatibility; ignored.

        Returns:
            The initial symptom observation.
        """
        if seed is not None:
            self._episode_rng = np.random.RandomState(seed)
        self.current_disease = self._sampler.choice(self.disease_list, p=self.stationary_distribution)
        self.current_disease_index = self.disease_list.index(self.current_disease)
        self.current_symptoms = self.sample_symptoms()
        self.visit_number = 0
        self.stage = 0
        return self.current_symptoms

    def step(self, action):
        """Apply a treatment and advance the episode by one visit.

        The treatment's base cost is always charged. If the treatment is one of
        the current disease's treatments, remission happens with the stored
        probability, ending the episode with ``remission_reward``. Otherwise
        new symptoms are sampled; an adverse event (largest symptom above
        ``adverse_event_threshold``) ends the episode with
        ``adverse_event_reward``, and the episode also ends once
        ``max_visits`` steps have been taken.

        NOTE(review): ``treatment_effects``, ``transition_modifiers`` and
        ``transition_matrix`` are not consulted here, so the disease never
        changes within an episode except to ``"Remission"``.

        Args:
            action: Treatment index in ``range(n_treatments)``.

        Returns:
            ``(observation, reward, done, info)`` where ``info`` contains
            ``"delta"`` (0 only when the episode ends by reaching
            ``max_visits``, else 1) and ``"stage"`` (steps taken so far).
        """
        self.visit_number += 1
        self.stage += 1
        reward = 0

        treatment = self.treatments[f"Treatment_{action}"]
        reward -= treatment["base_cost"]

        # Treatments not attached to the current disease have probability 0.
        if self._sampler.rand() < self.diseases[self.current_disease]["remission_probs"].get(action, 0):
            self.current_disease = "Remission"
            self.current_symptoms = self.sample_symptoms()
            reward += self.remission_reward
            done, delta = True, 1
            return self.current_symptoms, reward, done, {"delta": delta, "stage": self.stage}

        self.current_symptoms = self.sample_symptoms()
        adverse_event = self.current_symptoms.max() > self.adverse_event_threshold
        if adverse_event:
            reward += self.adverse_event_reward
            done, delta = True, 1
        else:
            done = self.visit_number == self.max_visits
            delta = 0 if done else 1

        return self.current_symptoms, reward, done, {"delta": delta, "stage": self.stage}


# Register under this module's own import path: identical to "__main__:EpiCare"
# when the file is run as a script, and also resolvable when the file is
# imported as a module.
register(id="EpiCare-v0", entry_point=f"{__name__}:EpiCare")